# -*- coding: utf-8 -*-
"""
Created on Tue Aug 29 17:35:23 2023
@author: walsworthlab
"""
import os

import numpy as np

# Directory holding the numbered input files, and the directory the cleaned
# copies are written to. The single backslash before "ML Current Inversion
# Paper" is part of the original path string; it is written as '\\' because
# '\M' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
file_dir = ('//10.229.62.137/data/Experiment folders/ML Inverse Problem\\'
            'ML Current Inversion Paper (2024)/training and validation data/'
            'type2_50um/')
save_dir = file_dir + 'temp/'

# NOTE(review): file_size is not referenced anywhere in this script.
file_size = 512
# Files with indices start_ind .. num_files-1 are processed.
num_files = 154
start_ind = 0

# Running count of spare samples consumed so far (also the index of the next
# spare sample to use).
count_b0 = 0

# Alternative datasets; uncomment one pair to switch the file prefixes.
# batch_str='thin_64'
# save_str='thin_64'

# Prefix of the files that are read (batch_str) and written (save_str).
batch_str = 'TrainingDat_256'
save_str = 'TrainingDat_256'

# batch_str='blend3_256TrainingDat_256'
# save_str='TrainingDat_256'

# Spare samples used to overwrite blank (all-zero) samples in the files
# processed below. Row k of data_in0, data_out0 and data_z0 is used together.
# With the current num_files=154, this index lies outside
# range(start_ind, num_files), so the spare file is not itself processed.
spare_ind = 155
spare_suffix = str(spare_ind).zfill(3) + '.npy'

file_in = file_dir + batch_str + '_in_' + spare_suffix
data_in0 = np.load(file_in)

file_out = file_dir + batch_str + '_out_' + spare_suffix
data_out0 = np.load(file_out)

file_z = file_dir + 'standoff_distances_' + spare_suffix
data_z0 = np.load(file_z)


# np.save does not create missing parent directories, so make sure the output
# directory exists once, before the loop.
os.makedirs(save_dir, exist_ok=True)

# The three spare arrays are consumed in lockstep, so the usable pool is the
# shortest of them.
n_spare = min(len(data_in0), len(data_out0), len(data_z0))

for nf in range(start_ind, num_files):
    print(nf)
    suffix = str(nf).zfill(3) + '.npy'

    data_in = np.load(file_dir + batch_str + '_in_' + suffix)
    data_out = np.load(file_dir + batch_str + '_out_' + suffix)
    data_z = np.load(file_dir + 'standoff_distances_' + suffix)

    # Indices (along the leading axis) of samples whose input is exactly zero
    # everywhere. "No entry != 0 over every trailing axis" is the same test as
    # the former chained max(abs(...), axis=1) == 0 for 4-D input, without the
    # abs() temporary or three separate reductions.
    has_signal = np.any(data_in != 0, axis=tuple(range(1, data_in.ndim)))
    b0 = np.flatnonzero(~has_signal)

    n_blank = len(b0)
    if n_blank:
        stop = count_b0 + n_blank
        # Fail with a clear message (before touching this file's arrays)
        # instead of an opaque IndexError partway through the replacement.
        if stop > n_spare:
            raise IndexError(
                'file {:03d} needs {} spare samples but only {} of {} remain'
                .format(nf, n_blank, n_spare - count_b0, n_spare))
        # Overwrite each blank sample (input, target and standoff distance)
        # with the next unused spare samples, in order.
        data_in[b0] = data_in0[count_b0:stop]
        data_out[b0] = data_out0[count_b0:stop]
        data_z[b0] = data_z0[count_b0:stop]
        count_b0 = stop

    np.save(save_dir + save_str + '_in_' + suffix, data_in)
    np.save(save_dir + save_str + '_out_' + suffix, data_out)
    np.save(save_dir + 'standoff_distances_' + suffix, data_z)

# Total number of spare samples consumed across all processed files.
print(count_b0)
